{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataset.ade as ade\n",
    "from dataset.ade import AdeSegmentationIncremental\n",
    "import dataset.transform as transforms\n",
    "import torchvision as tv\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = ade.AdeSegmentation('/vandal/datasets', train=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done 0 / 20210\n",
      "Done 1000 / 20210\n",
      "Done 2000 / 20210\n",
      "Done 3000 / 20210\n",
      "Done 4000 / 20210\n",
      "Done 5000 / 20210\n",
      "Done 6000 / 20210\n",
      "Done 7000 / 20210\n",
      "Done 8000 / 20210\n",
      "Done 9000 / 20210\n",
      "Done 10000 / 20210\n",
      "Done 11000 / 20210\n",
      "Done 12000 / 20210\n",
      "Done 13000 / 20210\n",
      "Done 14000 / 20210\n",
      "Done 15000 / 20210\n",
      "Done 16000 / 20210\n",
      "Done 17000 / 20210\n",
      "Done 18000 / 20210\n",
      "Done 19000 / 20210\n",
      "Done 20000 / 20210\n"
     ]
    }
   ],
   "source": [
    "def group_images(dataset, labels):\n",
    "    # Group images based on the label in LABELS\n",
    "    idxs = [[] for lab in labels]\n",
    "\n",
    "    labels_cum = labels + [0, 255]\n",
    "    for i in range(len(dataset)):\n",
    "        cls = np.unique(np.array(dataset[i][1]))\n",
    "        if all(x in labels_cum for x in cls):\n",
    "            for x in cls:\n",
    "                if x in labels:\n",
    "                    idxs[x].append(i)\n",
    "        else:\n",
    "            print(f\"Not all in labels_cum: {cls}\")\n",
    "        if i % 1000 == 0:\n",
    "            print(f\"Done {i} / {len(dataset)}\")\n",
    "    return idxs\n",
    "\n",
    "idxs = group_images(data, list(range(151)))\n",
    "# imgs_to_cls map all the images to all the classes contained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[20208, 11588, 6042, 8240, 9314, 6674, 6579, 3990, 1825, 4687, 2421, 2851, 3061, 5069, 1791, 3276, 4266, 1690, 3995, 2148, 3258, 3161, 702, 3295, 1191, 1404, 667, 650, 1229, 1132, 479, 1172, 521, 1385, 698, 781, 378, 3083, 404, 803, 1447, 411, 1437, 799, 2646, 549, 367, 310, 1181, 287, 468, 402, 130, 561, 880, 84, 172, 198, 930, 139, 564, 320, 261, 275, 335, 792, 395, 1308, 1112, 266, 658, 330, 531, 369, 144, 265, 324, 304, 170, 68, 65, 247, 492, 2508, 440, 135, 583, 532, 1987, 71, 618, 133, 83, 178, 992, 181, 451, 41, 307, 965, 117, 354, 108, 556, 52, 99, 57, 66, 77, 340, 66, 465, 50, 622, 80, 59, 530, 282, 73, 213, 136, 116, 266, 58, 417, 319, 1191, 97, 347, 52, 246, 108, 201, 284, 268, 1020, 1281, 528, 452, 671, 397, 92, 228, 556, 217, 206, 126, 178, 504, 773, 421]\n"
     ]
    }
   ],
   "source": [
    "len_idxs = [len(idxs[i]) for i in range(151)]\n",
    "print(len_idxs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20196\n"
     ]
    }
   ],
   "source": [
    "class_members = {}\n",
    "for i,l in enumerate(idxs):\n",
    "    if i==0:\n",
    "        continue\n",
    "    class_members[i]=len(l)\n",
    "\n",
    "#random.seed(42)\n",
    "a = np.arange(151) \n",
    "a=a[1:]\n",
    "a=np.random.permutation(a)\n",
    "ordered_classes = class_members.items()\n",
    "\n",
    "# map the image to the class contained, the inverse of idxs.\n",
    "imgs_to_cls = {}\n",
    "for j,v in ordered_classes:\n",
    "        for i in list(idxs[j]):\n",
    "            if not i in imgs_to_cls.keys():\n",
    "                imgs_to_cls[i]=[]\n",
    "            imgs_to_cls[i].append(j)\n",
    "print(len(imgs_to_cls))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check if all image has at least 1 object labeled (not all bkg)\n",
    "for i in imgs_to_cls.keys():\n",
    "    if len(imgs_to_cls[i])<1:\n",
    "        print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: [3, 5, 6, 13, 14, 20, 23, 25, 27, 31, 35, 39, 43, 44, 46, 47, 49, 50, 51, 52, 54, 56, 57, 58, 63, 65, 68, 69, 79, 86, 88, 90, 95, 98, 105, 108, 109, 113, 114, 118, 120, 125, 126, 131, 132, 136, 138, 146, 149, 150], 1: [7, 8, 16, 17, 18, 19, 21, 28, 29, 30, 34, 36, 38, 61, 66, 67, 71, 73, 75, 76, 78, 81, 82, 83, 84, 85, 87, 89, 91, 94, 97, 99, 103, 104, 106, 107, 115, 117, 122, 124, 127, 128, 133, 134, 140, 141, 144, 145, 147, 148], 2: [1, 2, 4, 9, 10, 11, 12, 15, 22, 24, 26, 32, 33, 37, 40, 41, 42, 45, 48, 53, 55, 59, 60, 62, 64, 70, 72, 74, 77, 80, 92, 93, 96, 100, 101, 102, 110, 111, 112, 116, 119, 121, 123, 129, 130, 135, 137, 139, 142, 143]}\n",
      "{0: 50, 1: 50, 2: 50}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# make the order for the dataset!\n",
    "# map_class_to_set maps the class to the corresponding step\n",
    "\n",
    "base_len = 50 # first training step: # of images\n",
    "inc_len = 50  # following training step (assume same dimension)\n",
    "sets = 2      # Num of incremental steps. E.g. the 50-50-50 has 2 inc. steps\n",
    "\n",
    "order = np.arange(1,151)\n",
    "np.random.shuffle(order) # make a random order\n",
    "\n",
    "map_set_to_class = {}\n",
    "map_class_to_set = {}\n",
    "\n",
    "map_set_to_class[0] = order[0:base_len]\n",
    "map_class_to_set.update({x:0 for x in map_set_to_class[0]})\n",
    "offset = base_len\n",
    "for i in range(1, sets+1):\n",
    "    map_set_to_class[i] = order[offset : offset + inc_len]\n",
    "    offset = offset + inc_len\n",
    "    map_class_to_set.update({x:i for x in map_set_to_class[i]})\n",
    "\n",
    "print({i:sorted(list(x)) for i,x in map_set_to_class.items()})\n",
    "print({i:len(x) for i,x in map_set_to_class.items()})\n",
    "\n",
    "map_class_to_set[1]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 0.3566620642043493 0.6568864342423196 11588 4133\n",
      "2 0.29344587884806356 0.6607083747103608 6042 1773\n",
      "3 0.3963592233009709 0.603640776699029 8240 3266\n",
      "4 0.35129911960489585 0.6398969293536612 9314 3272\n",
      "5 0.37413844770752175 0.6258615522924782 6674 2497\n",
      "6 0.3731570147438821 0.6268429852561179 6579 2455\n",
      "7 0.43107769423558895 0.7172932330827068 3990 1720\n",
      "8 0.3254794520547945 0.6202739726027398 1825 594\n",
      "9 0.3522509067633881 0.6460422445060807 4687 1651\n",
      "10 0.3985956216439488 0.6600578273440727 2421 965\n",
      "11 0.3444405471764293 0.6545071904594879 2851 982\n",
      "12 0.31950343025155176 0.7624959163672003 3061 978\n",
      "13 0.34819491023870586 0.6518050897612941 5069 1765\n",
      "14 0.47068676716917923 0.5293132328308208 1791 843\n",
      "15 0.35836385836385837 0.6504884004884005 3276 1174\n",
      "16 0.2817627754336615 0.597515236755743 4266 1202\n",
      "17 0.4325443786982249 0.6230769230769231 1690 731\n",
      "18 0.3406758448060075 0.672090112640801 3995 1361\n",
      "19 0.30353817504655495 0.5963687150837989 2148 652\n",
      "20 0.38275015346838553 0.6172498465316145 3258 1247\n",
      "21 0.4435305283138247 0.7431192660550459 3161 1402\n",
      "23 0.4267071320182094 0.5732928679817906 3295 1406\n",
      "24 0.2753988245172124 0.48698572628043657 1191 328\n",
      "25 0.396011396011396 0.603988603988604 1404 556\n",
      "28 0.3954434499593165 0.6737184703010577 1229 486\n",
      "29 0.29858657243816256 0.5627208480565371 1132 338\n",
      "33 0.4332129963898917 0.7350180505415163 1385 600\n",
      "37 0.3065196237431074 0.5776840739539411 3083 945\n",
      "39 0.43462017434620176 0.5653798256537983 803 349\n",
      "40 0.2833448514167243 0.5238424326192122 1447 410\n",
      "42 0.3848295059151009 0.6840640222686152 1437 553\n",
      "43 0.4292866082603254 0.5707133917396746 799 343\n",
      "44 0.30687830687830686 0.6931216931216931 2646 812\n",
      "48 0.40474174428450466 0.7874682472480948 1181 478\n",
      "51 0.3880597014925373 0.6119402985074627 402 156\n",
      "54 0.4375 0.5625 880 385\n",
      "55 0.9047619047619048 1.0 84 76\n",
      "58 0.45268817204301076 0.5473118279569893 930 421\n",
      "67 0.3195718654434251 0.6368501529051989 1308 418\n",
      "68 0.4847122302158273 0.5152877697841727 1112 539\n",
      "83 0.29784688995215314 0.6495215311004785 2508 747\n",
      "86 0.4939965694682676 0.5060034305317325 583 288\n",
      "88 0.31806743834927026 0.6819325616507297 1987 632\n",
      "89 0.9859154929577465 1.0 71 70\n",
      "91 0.5338345864661654 1.0 133 71\n",
      "92 0.9879518072289156 0.9879518072289156 83 82\n",
      "99 0.39067357512953366 0.7274611398963731 965 377\n",
      "105 0.9191919191919192 0.08080808080808081 99 91\n",
      "107 0.9696969696969697 0.9848484848484849 66 64\n",
      "108 0.974025974025974 0.025974025974025976 77 75\n",
      "110 0.9696969696969697 1.0 66 64\n",
      "113 0.4517684887459807 0.5482315112540193 622 281\n",
      "114 0.9125 0.0875 80 73\n",
      "123 0.9827586206896551 1.0 58 57\n",
      "125 0.39184952978056425 0.6081504702194358 319 125\n",
      "126 0.3879093198992443 0.6120906801007557 1191 462\n",
      "127 0.9484536082474226 0.979381443298969 97 92\n",
      "129 0.9807692307692307 0.9807692307692307 52 51\n",
      "131 0.9166666666666666 0.08333333333333333 108 99\n",
      "135 0.3911764705882353 0.7039215686274509 1020 399\n",
      "136 0.39968774395003903 0.600312256049961 1281 512\n",
      "139 0.43517138599105815 0.849478390461997 671 292\n",
      "141 0.9456521739130435 1.0 92 87\n",
      "148 0.49404761904761907 0.7678571428571429 504 249\n",
      "149 0.45666235446313064 0.5433376455368694 773 353\n",
      "7646\n",
      "6149\n",
      "6401\n",
      "20196\n"
     ]
    }
   ],
   "source": [
    "# This is the code who actually make the split.\n",
    "\n",
    "all_added = set()\n",
    "\n",
    "imgs={}\n",
    "others = {}\n",
    "every=set()\n",
    "added = [0 for i in range(151)]\n",
    "for i in range(sets+1):\n",
    "    others[i]=set()\n",
    "    imgs[i]=set()\n",
    "\n",
    "# This is the most imporant function. \n",
    "# It computes the score of the image to be assinged to a class\n",
    "def score(ratio,imgs,expected,minval=False,w1=120,w2=1000,w3=1 ):\n",
    "    return (1.-ratio)*w1+w3*minval\n",
    "\n",
    "for i in imgs_to_cls.keys(): # loop for every image in the dataset\n",
    "    # img_counts = num of images per class for the class in the actual image  \n",
    "    img_counts = [class_members[j] for j in imgs_to_cls[i]] \n",
    "    # ratio = images assigned to the class / total number of images for the class  \n",
    "    ratios = [(added[j]+0.0)/class_members[j] for j in imgs_to_cls[i]]\n",
    "    added_counts = [added[j] for j in imgs_to_cls[i]]\n",
    "    # assignments = set of each class in the image\n",
    "    assignments = [map_class_to_set[j] for j in imgs_to_cls[i]]\n",
    "    \n",
    "    scores = [score(ratios[c],added_counts[c],class_members[j],1./(img_counts[c]/sum(img_counts))) for c,j in enumerate(imgs_to_cls[i])]\n",
    "    \n",
    "    # take the highest scorer class\n",
    "    cl = scores.index(max(scores))\n",
    "    # take the set of the higher scorer class\n",
    "    a=assignments[cl]\n",
    "    \n",
    "    # add the image to the step images\n",
    "    imgs[a].update([i])\n",
    "    for j,ac in enumerate(assignments):\n",
    "        if ac==a:\n",
    "            # increment the number of images assigned to the classes in the current image contained in the same step.\n",
    "            added[imgs_to_cls[i][j]] += 1\n",
    "            \n",
    "            \n",
    "for i in range(1,len(added)):        \n",
    "        assignment = map_class_to_set[i]\n",
    "        ratios = [len(set(idxs[i]).intersection(imgs[j]))/class_members[i] for j in range(0,sets+1)]\n",
    "        if ratios[assignment]<0.5 or (ratios[assignment]<1. and added[i]<100):\n",
    "            print(i,ratios[assignment], sum(ratios[1:]), class_members[i], len(set(idxs[i]).intersection(imgs[assignment])))\n",
    "s=0\n",
    "for i in range(sets+1):\n",
    "    print(len(imgs[i]))\n",
    "    s+=len(imgs[i])\n",
    "    \n",
    "print(s)\n",
    "            \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def class_occurrencies(data, idxs, n_cl=151):\n",
    "    classes = np.zeros(n_cl)\n",
    "    for i in idxs:\n",
    "        for cls in imgs_to_cls[i]:\n",
    "            classes[cls] += 1\n",
    "    return classes\n",
    "\n",
    "# for i in range(sets+1):\n",
    "#     print(f\"{i} : {len(imgs[i])}\")\n",
    "#     for j in range(sets+1):\n",
    "#         print(f\"{i} : {j}={map_set_to_class[j]}-{class_occurrencies(data, imgs[i])[map_set_to_class[j]]}\")\n",
    "#     print(\"\\n\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: array([125,   6,  54,  56,  57,  63, 150,  90,  58,  51,  68, 136,  31,\n",
       "         14,   3, 118,  20,  86,  50,  13, 132,  44,  69,  46,  23, 108,\n",
       "         43, 138, 126, 149,  88,  27,  47, 109,   5,  65,  49, 105,  25,\n",
       "        113,  98,  95,  35, 146,  39, 120, 131, 114,  79,  52]),\n",
       " 1: array([ 28,  30,  99,  83, 128, 148,   7,  16,  29,  75,  34,  71, 144,\n",
       "         67,  84, 147, 122,  78,  81, 134, 117,  94,  82,  36,  73, 140,\n",
       "         21,  38,  76, 141, 106, 107, 133,  18,  97,  61, 115, 127,  87,\n",
       "        124,   8, 145,  66, 103,  85,  19,  91,  89, 104,  17]),\n",
       " 2: array([ 32,  22,  48, 112,  40,  59,  60,  10, 121, 111,   1, 116, 129,\n",
       "         37,  53,  70,  77,  74,  24, 100,  33,   9,  12,   2,  11,  80,\n",
       "         15, 102, 143, 101,  72, 137, 110,  93,  41, 135,  55,  26,  92,\n",
       "        123,  96, 142,  62,  45,  42, 139, 130,  64, 119,   4])}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "map_set_to_class_id = {}\n",
    "for i in range(sets+1):\n",
    "    map_set_to_class_id[i] = np.array(list(range(151)))[map_set_to_class[i]]\n",
    "map_set_to_class_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7646\n",
      "6149\n",
      "6401\n"
     ]
    }
   ],
   "source": [
    "path = '../split/'\n",
    "os.makedirs(path, exist_ok=True)\n",
    "with open(path + 'order.txt', 'w') as f:\n",
    "    f.write(str({i:sorted(list(x)) for i,x in map_set_to_class_id.items()}))\n",
    "for i in range(sets+1):\n",
    "    np.save(path + 'train-'+str(i)+'.npy',  np.array(list(imgs[i])))\n",
    "    print(len(imgs[i]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
